Spark ML

Read training and test data. In this case test data is labeled as well (we will generate our label based on the arrdelay field)


In [15]:
training = sqlContext.read.parquet("data/training.parquet")
test = sqlContext.read.parquet("data/test.parquet")

In [16]:
test.printSchema()


root
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- dayofmonth: integer (nullable = true)
 |-- dayofweek: integer (nullable = true)
 |-- deptime: integer (nullable = true)
 |-- crsdeptime: integer (nullable = true)
 |-- arrtime: integer (nullable = true)
 |-- crsarrtime: integer (nullable = true)
 |-- actualelapsetime: integer (nullable = true)
 |-- crselapsetime: integer (nullable = true)
 |-- airtime: integer (nullable = true)
 |-- arrdelay: integer (nullable = true)
 |-- depdelay: integer (nullable = true)
 |-- distance: integer (nullable = true)
 |-- taxiin: integer (nullable = true)
 |-- taxiout: integer (nullable = true)
 |-- cancelled: integer (nullable = true)
 |-- diverted: integer (nullable = true)
 |-- carrierdelay: integer (nullable = true)
 |-- weatherdelay: integer (nullable = true)
 |-- nasdelay: integer (nullable = true)
 |-- securitydelay: integer (nullable = true)
 |-- lateaircraftdelay: integer (nullable = true)


In [18]:
test.first()


Out[18]:
Row(year=2006, month=2, dayofmonth=21, dayofweek=2, deptime=902, crsdeptime=905, arrtime=1027, crsarrtime=1030, actualelapsetime=205, crselapsetime=205, airtime=190, arrdelay=-3, depdelay=-3, distance=1162, taxiin=7, taxiout=8, cancelled=0, diverted=0, carrierdelay=0, weatherdelay=0, nasdelay=0, securitydelay=0, lateaircraftdelay=0)

Generate label column for the training data


In [19]:
from pyspark.sql.types import DoubleType
from pyspark.sql.functions import udf

is_late = udf(lambda delay: 1.0 if delay > 0 else 0.0, DoubleType())
training = training.withColumn("is_late",is_late(training.arrdelay))

Create and fit Spark ML model


In [21]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline

# Create feature vectors. Ignore arr_delay and it's derivate, is_late
feature_assembler = VectorAssembler(
    inputCols=[x for x in training.columns if x not in ["is_late","arrdelay"]],
    outputCol="features")

reg = LogisticRegression().setParams(
    maxIter = 100,
    labelCol="is_late",
    predictionCol="prediction")

model = Pipeline(stages=[feature_assembler, reg]).fit(training)

In [20]:
[x for x in training.columns if x not in ["is_late","arrdelay"]]


Out[20]:
['year',
 'month',
 'dayofmonth',
 'dayofweek',
 'deptime',
 'crsdeptime',
 'arrtime',
 'crsarrtime',
 'actualelapsetime',
 'crselapsetime',
 'airtime',
 'depdelay',
 'distance',
 'taxiin',
 'taxiout',
 'cancelled',
 'diverted',
 'carrierdelay',
 'weatherdelay',
 'nasdelay',
 'securitydelay',
 'lateaircraftdelay']

In [ ]:

Predict whether the aircraft will be late


In [22]:
predicted = model.transform(test)

In [23]:
predicted.show()


+----+-----+----------+---------+-------+----------+-------+----------+----------------+-------------+-------+--------+--------+--------+------+-------+---------+--------+------------+------------+--------+-------------+-----------------+--------------------+--------------------+--------------------+----------+
|year|month|dayofmonth|dayofweek|deptime|crsdeptime|arrtime|crsarrtime|actualelapsetime|crselapsetime|airtime|arrdelay|depdelay|distance|taxiin|taxiout|cancelled|diverted|carrierdelay|weatherdelay|nasdelay|securitydelay|lateaircraftdelay|            features|       rawPrediction|         probability|prediction|
+----+-----+----------+---------+-------+----------+-------+----------+----------------+-------------+-------+--------+--------+--------+------+-------+---------+--------+------------+------------+--------+-------------+-----------------+--------------------+--------------------+--------------------+----------+
|2006|    2|        21|        2|    902|       905|   1027|      1030|             205|          205|    190|      -3|      -3|    1162|     7|      8|        0|       0|           0|           0|       0|            0|                0|[2006.0,2.0,21.0,...|[0.75595747136043...|[0.68047541434396...|       0.0|
|2005|    7|        26|        2|   2147|      2130|   2338|      2340|             111|          130|     98|      -2|      17|     737|     3|     10|        0|       0|           0|           0|       0|            0|                0|[2005.0,7.0,26.0,...|[0.10038430116296...|[0.52507502206005...|       0.0|
|2006|    2|        16|        4|   1038|      1025|   1346|      1353|             128|          148|    118|      -7|      13|    1038|     3|      7|        0|       0|           0|           0|       0|            0|                0|[2006.0,2.0,16.0,...|[0.59474584225762...|[0.64445331985803...|       0.0|
|2006|   11|         3|        5|   1131|      1130|   1545|      1550|             194|          200|    169|      -5|       1|    1444|     5|     20|        0|       0|           0|           0|       0|            0|                0|[2006.0,11.0,3.0,...|[0.33650329872936...|[0.58334088313174...|       0.0|
|2005|    2|        10|        4|   1438|      1435|   1846|      1805|             188|          150|    136|      41|       3|     946|     6|     46|        0|       0|           0|           0|      41|            0|                0|[2005.0,2.0,10.0,...|[-1.7967604665385...|[0.14224586889189...|       1.0|
|2005|    9|        22|        4|   1324|      1323|   1622|      1626|             178|          183|    158|      -4|       1|    1145|     6|     14|        0|       0|           0|           0|       0|            0|                0|[2005.0,9.0,22.0,...|[0.43430169698074...|[0.60690040475391...|       0.0|
|2006|    8|         1|        2|    624|       624|   1052|      1048|             208|          204|    190|       4|       0|    1501|     2|     16|        0|       0|           0|           0|       0|            0|                0|[2006.0,8.0,1.0,2...|[0.57520949895389...|[0.63996436880183...|       0.0|
|2006|    9|        19|        2|   1354|      1325|   1651|      1610|             117|          105|     86|      41|      29|     552|    17|     14|        0|       0|          29|           0|      12|            0|                0|[2006.0,9.0,19.0,...|[-0.8262033756838...|[0.30444844436252...|       1.0|
|2007|    5|         7|        1|   1345|      1345|   1633|      1623|             168|          158|    143|      10|       0|     930|     5|     20|        0|       0|           0|           0|       0|            0|                0|[2007.0,5.0,7.0,1...|[0.21298482279816...|[0.55304583268858...|       0.0|
|2008|   12|         3|        3|   1736|      1715|   1844|      1815|              68|           60|     47|      29|      21|     190|    14|      7|        0|       0|           0|           0|       8|            0|               21|[2008.0,12.0,3.0,...|[-0.2252103760178...|[0.44393417612962...|       1.0|
|2005|    3|         6|        7|   2042|      2030|   2119|      2105|              97|           95|     76|      14|      12|     575|     4|     17|        0|       0|           0|           0|       0|            0|                0|[2005.0,3.0,6.0,7...|[0.01849917836531...|[0.50462466270456...|       0.0|
|2003|    7|        18|        5|   1810|      1817|   1935|      2002|             145|          165|    130|     -27|      -7|     925|     6|      9|        0|       0|           0|           0|       0|            0|                0|[2003.0,7.0,18.0,...|[0.68704951087341...|[0.66531025466928...|       0.0|
|2005|   10|        10|        1|   1712|      1657|   1912|      1904|             240|          247|    212|       8|      15|    1557|    11|     17|        0|       0|           0|           0|       0|            0|                0|[2005.0,10.0,10.0...|[-0.0815396608142...|[0.47962637175193...|       1.0|
|2006|   10|         1|        7|   1448|      1448|   1548|      1600|              60|           72|     46|     -12|       0|     229|     2|     12|        0|       0|           0|           0|       0|            0|                0|[2006.0,10.0,1.0,...|[0.67406810508180...|[0.66241347664364...|       0.0|
|2004|    5|         2|        7|    810|       810|    931|       939|              81|           89|     68|      -8|       0|     500|     2|     11|        0|       0|           0|           0|       0|            0|                0|[2004.0,5.0,2.0,7...|[0.85305776081439...|[0.70120818394559...|       0.0|
|2008|    4|         4|        5|   1115|      1010|   1359|      1132|             164|           82|     70|     147|      65|     395|     4|     90|        0|       0|          65|           0|      82|            0|                0|[2008.0,4.0,4.0,5...|[-6.3209062697130...|[0.00179508488638...|       1.0|
|2005|    1|        16|        7|   1305|      1305|   1427|      1431|              82|           86|     65|      -4|       0|     489|     4|     13|        0|       0|           0|           0|       0|            0|                0|[2005.0,1.0,16.0,...|[0.58981194251394...|[0.64332199541140...|       0.0|
|2007|   12|         9|        7|   1634|      1635|   1943|      1928|             189|          173|    137|      15|      -1|    1028|    39|     13|        0|       0|           0|           0|      15|            0|                0|[2007.0,12.0,9.0,...|[0.02920324906575...|[0.50730029344818...|       0.0|
|2004|    9|        17|        5|   1758|      1800|   1921|      1932|              83|           92|     68|     -11|      -2|     334|     8|      7|        0|       0|           0|           0|       0|            0|                0|[2004.0,9.0,17.0,...|[0.69985325368104...|[0.66815523581717...|       0.0|
|2007|    2|        26|        1|   1654|      1445|   2041|      1823|             167|          158|    157|     138|     129|    1088|     3|      7|        0|       0|           5|           0|      23|            0|              110|[2007.0,2.0,26.0,...|[-4.2597873929270...|[0.01392856006835...|       1.0|
+----+-----+----------+---------+-------+----------+-------+----------+----------------+-------------+-------+--------+--------+--------+------+-------+---------+--------+------------+------------+--------+-------------+-----------------+--------------------+--------------------+--------------------+----------+
only showing top 20 rows


In [24]:
predicted.select("is_late", "prediction").show()


---------------------------------------------------------------------------
AnalysisException                         Traceback (most recent call last)
<ipython-input-24-eadcb64d2c30> in <module>()
----> 1 predicted.select("is_late", "prediction").show()

/opt/spark-1.6.1/python/pyspark/sql/dataframe.pyc in select(self, *cols)
    860         [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
    861         """
--> 862         jdf = self._jdf.select(self._jcols(*cols))
    863         return DataFrame(jdf, self.sql_ctx)
    864 

/opt/spark-1.6.1/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
    811         answer = self.gateway_client.send_command(command)
    812         return_value = get_return_value(
--> 813             answer, self.gateway_client, self.target_id, self.name)
    814 
    815         for temp_arg in temp_args:

/opt/spark-1.6.1/python/pyspark/sql/utils.pyc in deco(*a, **kw)
     49                                              e.java_exception.getStackTrace()))
     50             if s.startswith('org.apache.spark.sql.AnalysisException: '):
---> 51                 raise AnalysisException(s.split(': ', 1)[1], stackTrace)
     52             if s.startswith('java.lang.IllegalArgumentException: '):
     53                 raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace)

AnalysisException: u"cannot resolve 'is_late' given input columns: [crselapsetime, dayofweek, taxiout, month, probability, carrierdelay, prediction, nasdelay, dayofmonth, lateaircraftdelay, rawPrediction, crsdeptime, airtime, year, securitydelay, cancelled, arrdelay, weatherdelay, actualelapsetime, arrtime, diverted, distance, features, depdelay, crsarrtime, deptime, taxiin];"

Check model performance


In [25]:
predicted = predicted.withColumn("is_late",is_late(predicted.arrdelay))
predicted.select("is_late", "prediction").show()


+-------+----------+
|is_late|prediction|
+-------+----------+
|    0.0|       0.0|
|    0.0|       0.0|
|    0.0|       0.0|
|    0.0|       0.0|
|    1.0|       1.0|
|    0.0|       0.0|
|    1.0|       0.0|
|    1.0|       1.0|
|    1.0|       0.0|
|    1.0|       1.0|
|    1.0|       0.0|
|    0.0|       0.0|
|    1.0|       1.0|
|    0.0|       0.0|
|    0.0|       0.0|
|    1.0|       1.0|
|    0.0|       0.0|
|    1.0|       0.0|
|    0.0|       0.0|
|    1.0|       1.0|
+-------+----------+
only showing top 20 rows


In [26]:
predicted.crosstab("is_late","prediction").show()


+------------------+----+----+
|is_late_prediction| 1.0| 0.0|
+------------------+----+----+
|               1.0|1448|1110|
|               0.0|  62|2805|
+------------------+----+----+


In [ ]: